{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# ZhongJingGPT-2-1.8b\n",
        "\n",
        "A Traditional Chinese Medicine large language model, inspired by the wisdom of the eminent representative of ancient Chinese medical scholars, Zhang Zhongjing. This model aims to illuminate the profound knowledge of Traditional Chinese Medicine, bridging the gap between ancient wisdom and modern technology, and providing a reliable and professional tool for the Traditional Chinese Medical fields. However, all generated results are for reference only and should be provided by experienced professionals for diagnosis and treatment results and suggestions."
      ],
      "metadata": {
        "id": "NKOuDx1olwGw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "print(torch.cuda.is_available())\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JM-BsUrpeWJT",
        "outputId": "8d593699-2995-452d-c8be-936fafa0249e"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "True\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install transformers huggingface_hub accelerate peft"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WFDppZ3udxdz",
        "outputId": "eac02119-90b6-478d-eb81-eba0f0232491"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.37.2)\n",
            "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.20.3)\n",
            "Collecting accelerate\n",
            "  Downloading accelerate-0.27.2-py3-none-any.whl (279 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.0/280.0 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting peft\n",
            "  Downloading peft-0.8.2-py3-none-any.whl (183 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m183.4/183.4 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
            "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.2)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.2)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2023.6.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.9.0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
            "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu121)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.3)\n",
            "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n",
            "Installing collected packages: accelerate, peft\n",
            "Successfully installed accelerate-0.27.2 peft-0.8.2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# You should restart colab and the run the following code."
      ],
      "metadata": {
        "id": "PNSjkWCqmM8V"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "import torch\n",
        "\n",
        "# Set the device\n",
        "device = \"cuda\"  # replace with your device: \"cpu\", \"cuda\", \"mps\"\n",
        "\n",
        "# Initialize model and tokenizer\n",
        "peft_model_id = \"CMLL/ZhongJing-2-1_8b\"\n",
        "base_model_id = \"Qwen/Qwen1.5-1.8B-Chat\"\n",
        "model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map=\"auto\")\n",
        "model.load_adapter(peft_model_id)\n",
        "tokenizer = AutoTokenizer.from_pretrained(\n",
        "    \"CMLL/ZhongJing-2-1_8b\",\n",
        "    padding_side=\"right\",\n",
        "    trust_remote_code=True,\n",
        "    pad_token=''\n",
        ")\n",
        "\n",
        "def get_model_response(question, context):\n",
        "    # Create the prompt\n",
        "    prompt = f\"Question: {question}\\nContext: {context}\"\n",
        "    messages = [\n",
        "        {\"role\": \"system\", \"content\": \"You are a helpful TCM assistant named 仲景中医大语言模型.\"},\n",
        "        {\"role\": \"user\", \"content\": prompt}\n",
        "    ]\n",
        "\n",
        "    # Prepare the input\n",
        "    text = tokenizer.apply_chat_template(\n",
        "        messages,\n",
        "        tokenize=False,\n",
        "        add_generation_prompt=True\n",
        "    )\n",
        "    model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
        "\n",
        "    # Generate the response\n",
        "    generated_ids = model.generate(\n",
        "        model_inputs.input_ids,\n",
        "        max_new_tokens=512\n",
        "    )\n",
        "    generated_ids = [\n",
        "        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
        "    ]\n",
        "\n",
        "    # Decode the response\n",
        "    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
        "    return response\n",
        "\n",
        "# Loop to get user input and provide model response\n",
        "while True:\n",
        "    user_question = input(\"Enter your question (or type 'exit' to stop): \")\n",
        "    if user_question.lower() == 'exit':\n",
        "        break\n",
        "    user_context = input(\"Enter context (or type 'none' if no context): \")\n",
        "    if user_context.lower() == 'none':\n",
        "        user_context = \"\"\n",
        "\n",
        "    print(\"Model is generating a response, please wait...\")\n",
        "    model_response = get_model_response(user_question, user_context)\n",
        "    print(\"Model's response:\", model_response)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "jsn4szdjdtmF",
        "outputId": "900e42e2-23be-4fb3-91d7-b586ba2d18a5"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n",
            "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
            "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Enter your question (or type 'exit' to stop): 你是谁\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: 我是“仲景中医大语言模型”，也叫“仲景大医”或“大医大”。\n",
            "Enter your question (or type 'exit' to stop): 我发热，咳嗽，呼吸困难，给出中医诊断和处方\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: “发高热，恶寒，咳嗽，喘促不能平卧者”属于“肺热壅盛”，患者应该根据自身病情采取中西医结合的方法来进行治疗。中医常采用清热解毒的药物，如白虎汤、黄连阿胶汤、银翘散等；西药常用布洛芬、对乙酰氨基酚等进行退烧。当症状缓解时，可以使用疏风清热、宣肺降气的中药方剂，如麻杏石甘汤或加减葳蕤汤来治疗。\n",
            "Enter your question (or type 'exit' to stop): 我还咳嗽，痰黄，该怎么办\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: 患者出现咳喘症状时，应及时就医。根据不同的病情和病因采取相应措施，如止咳、化痰、解表等。\n",
            "1、如果患者有发热等全身不适的表现，则可以使用消炎药、抗病毒药物进行治疗；\n",
            "2、如果患者是因感冒、肺炎导致的咳喘，则可以服用一些清热、解毒、止咳、化痰的中药，或者给予雾化吸入治疗；\n",
            "3、如果患者有咳嗽的症状，并伴有痰黄，可以在医生的指导下应用甘草合剂、川贝枇杷露、鲜竹沥液等药物进行治疗；\n",
            "4、如果是外感风寒所致的咳喘，则可以使用生姜、葱白、红糖煎水进行治疗，同时还可以应用疏风散寒、温肺化饮的中药。\n",
            "总之，在发生咳嗽的同时，还需结合具体的病因、临床表现和伴随的其他相关症状来综合分析，才能准确评估患者的病情并及时进行相应的治疗。\n",
            "Enter your question (or type 'exit' to stop): 谢谢\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: 您好，非常感谢您的提问。\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "Interrupted by user",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-1-bd9c00e18996>\u001b[0m in \u001b[0;36m<cell line: 49>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     48\u001b[0m \u001b[0;31m# Loop to get user input and provide model response\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m     \u001b[0muser_question\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Enter your question (or type 'exit' to stop): \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     51\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0muser_question\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'exit'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m         \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mraw_input\u001b[0;34m(self, prompt)\u001b[0m\n\u001b[1;32m    849\u001b[0m                 \u001b[0;34m\"raw_input was called, but this frontend does not support input requests.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    850\u001b[0m             )\n\u001b[0;32m--> 851\u001b[0;31m         return self._input_request(str(prompt),\n\u001b[0m\u001b[1;32m    852\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_ident\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    853\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_header\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36m_input_request\u001b[0;34m(self, prompt, ident, parent, password)\u001b[0m\n\u001b[1;32m    893\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    894\u001b[0m                 \u001b[0;31m# re-raise KeyboardInterrupt, to truncate traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 895\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Interrupted by user\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    896\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    897\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Invalid Message:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_info\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: Interrupted by user"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Q8pwg9UlitWI"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
